{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf-densenet121.ckpt.data-00000-of-00001\n",
      "tf-densenet121.ckpt.index\n",
      "tf-densenet121.ckpt.meta\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "File ‘tf-densenet121.tar.gz’ not modified on server. Omitting download.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%%bash \n",
    "# Download the DenseNet module (densenet.py, from the repo by pudae) and the model checkpoint.\n",
    "# TODO: check whether tf-slim gets a densenet121 of its own at some point\n",
    "wget -N https://github.com/pudae/tensorflow-densenet/raw/master/nets/densenet.py\n",
    "wget -N https://ikpublictutorial.blob.core.windows.net/deeplearningframeworks/tf-densenet121.tar.gz\n",
    "tar xzvf tf-densenet121.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#######################################################################################################\n",
    "# Summary\n",
    "# 1. Tensorflow Multi-GPU example using Estimator & Dataset high-APIs\n",
    "# 2. On-the-fly data-augmentation (random crop, random flip)\n",
    "# ToDo:\n",
    "# 3. Investigate tfrecord speed improvement (to match MXNet)\n",
    "# References:\n",
    "# https://www.tensorflow.org/performance/performance_guide\n",
    "# 1. https://jhui.github.io/2017/03/07/TensorFlow-Perforamnce-and-advance-topics/\n",
    "# 2. https://www.tensorflow.org/versions/master/performance/datasets_performance\n",
    "# 3. https://github.com/pudae/tensorflow-densenet\n",
    "# 4. https://stackoverflow.com/a/48096625/6772173\n",
    "# 5. https://stackoverflow.com/questions/47867748/transfer-learning-with-tf-estimator-estimator-framework\n",
    "# 6. https://github.com/BobLiu20/Classification_Nets/blob/master/tensorflow/common/average_gradients.py\n",
    "# 7. https://github.com/BobLiu20/Classification_Nets/blob/master/tensorflow/training/train_estimator.py\n",
    "#######################################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "MULTI_GPU = True  # TOGGLE THIS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import multiprocessing\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import dtypes\n",
    "from tensorflow.python.framework.ops import convert_to_tensor\n",
    "from common.utils import download_data_chextxray, get_imgloc_labels, get_train_valid_test_split\n",
    "from common.utils import compute_roc_auc, get_cuda_version, get_cudnn_version, get_gpu_name\n",
    "from common.params_dense import *\n",
    "slim = tf.contrib.slim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#https://github.com/pudae/tensorflow-densenet/raw/master/nets/densenet.py\n",
    "import densenet  # Download from https://github.com/pudae/tensorflow-densenet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.4 |Anaconda custom (64-bit)| (default, Nov 20 2017, 18:44:38) \n",
      "[GCC 7.2.0]\n",
      "Numpy:  1.14.1\n",
      "Tensorflow:  1.8.0\n",
      "GPU:  ['Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB']\n",
      "CUDA Version 9.0.176\n",
      "CuDNN Version  7.0.5\n"
     ]
    }
   ],
   "source": [
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"Tensorflow: \", tf.__version__)\n",
    "print(\"GPU: \", get_gpu_name())\n",
    "print(get_cuda_version())\n",
    "print(\"CuDNN Version \", get_cudnn_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPUs:  24\n",
      "GPUs:  4\n"
     ]
    }
   ],
   "source": [
    "CPU_COUNT = multiprocessing.cpu_count()\n",
    "GPU_COUNT = len(get_gpu_name())\n",
    "print(\"CPUs: \", CPU_COUNT)\n",
    "print(\"GPUs: \", GPU_COUNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chestxray/images chestxray/Data_Entry_2017.csv\n"
     ]
    }
   ],
   "source": [
    "# Model-params\n",
    "IMAGENET_RGB_MEAN_CAFFE = np.array([123.68, 116.78, 103.94], dtype=np.float32)\n",
    "IMAGENET_SCALE_FACTOR_CAFFE = 0.017\n",
    "# Paths\n",
    "CSV_DEST = \"chestxray\"\n",
    "IMAGE_FOLDER = os.path.join(CSV_DEST, \"images\")\n",
    "LABEL_FILE = os.path.join(CSV_DEST, \"Data_Entry_2017.csv\")\n",
    "print(IMAGE_FOLDER, LABEL_FILE)\n",
    "CHKPOINT = 'tf-densenet121.ckpt'  # Downloaded tensorflow-checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually scale to multi-gpu.\n",
    "# Both values are derived from the untouched defaults in common.params_dense rather than\n",
    "# updated in place, so re-running this cell (or toggling MULTI_GPU and re-running it)\n",
    "# never compounds the scaling.\n",
    "import common.params_dense as params_dense\n",
    "\n",
    "gpu_scale = GPU_COUNT if MULTI_GPU else 1\n",
    "LR = params_dense.LR * gpu_scale\n",
    "BATCHSIZE = params_dense.BATCHSIZE * gpu_scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please make sure to download\n",
      "https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\n",
      "Data already exists\n",
      "CPU times: user 625 ms, sys: 225 ms, total: 850 ms\n",
      "Wall time: 849 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Download data\n",
    "print(\"Please make sure to download\")\n",
    "print(\"https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\")\n",
    "download_data_chextxray(CSV_DEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class XrayData():\n",
    "    \"\"\"Batched tf.data pipeline over chest X-ray PNG files and their labels.\n",
    "\n",
    "    Image paths and labels come from get_imgloc_labels(); the finished\n",
    "    tf.data.Dataset is exposed as `self.data` and yields channels-first\n",
    "    (image, label) batches.\n",
    "\n",
    "    `mode` selects the pipeline:\n",
    "      'training'   - shuffle, repeat, random crop + random left/right flip\n",
    "      'validation' - repeat, resize only\n",
    "      'testing'    - single pass, resize only\n",
    "    Any other value raises ValueError.\n",
    "    \"\"\"\n",
    "\n",
    "    MODES = ('training', 'validation', 'testing')\n",
    "\n",
    "    def __init__(self, img_dir, lbl_file, patient_ids, mode,\n",
    "                 width=WIDTH, height=HEIGHT, batch_size=BATCHSIZE,\n",
    "                 imagenet_mean=IMAGENET_RGB_MEAN_CAFFE, imagenet_scaling = IMAGENET_SCALE_FACTOR_CAFFE,\n",
    "                 buffer=10):\n",
    "        # Fail fast: an unrecognised mode used to fall through every branch and leave\n",
    "        # self.data as an un-mapped, un-batched dataset of (path, label) pairs.\n",
    "        if mode not in self.MODES:\n",
    "            raise ValueError('mode must be one of {}, got {!r}'.format(self.MODES, mode))\n",
    "\n",
    "        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)\n",
    "        self.data_size = len(self.labels)\n",
    "        self.imagenet_mean = imagenet_mean\n",
    "        self.imagenet_scaling = imagenet_scaling\n",
    "        self.width = width\n",
    "        self.height = height\n",
    "        data = tf.data.Dataset.from_tensor_slices((self.img_locs, self.labels))\n",
    "\n",
    "        # Processing\n",
    "        # Both parse functions emit channels-first images (see the tf.transpose calls below)\n",
    "        if mode == 'training':\n",
    "            # Shuffle over the whole set, repeat, and augment\n",
    "            data = data.shuffle(self.data_size).repeat()\n",
    "            parse_fn = self._parse_function_train\n",
    "        elif mode == 'validation':\n",
    "            # Repeat, no augmentation\n",
    "            data = data.repeat()\n",
    "            parse_fn = self._parse_function_inference\n",
    "        else:\n",
    "            # 'testing': no repeat, no augmentation\n",
    "            parse_fn = self._parse_function_inference\n",
    "\n",
    "        # Fused map + batch, then keep `buffer` batches prefetched\n",
    "        self.data = data.apply(\n",
    "            tf.contrib.data.map_and_batch(parse_fn, batch_size)).prefetch(buffer)\n",
    "        print(\"Loaded {} labels and {} images\".format(len(self.labels), len(self.img_locs)))\n",
    "\n",
    "\n",
    "    def _parse_function_train(self, filename, label):\n",
    "        \"\"\"Load one training example: random crop to (height, width), random left/right flip.\"\"\"\n",
    "        img_rgb, label = self._preprocess_image_labels(filename, label)\n",
    "        # Random crop (from 264x264)\n",
    "        img_rgb = tf.random_crop(img_rgb, [self.height, self.width, 3])\n",
    "        # Random flip\n",
    "        img_rgb = tf.image.random_flip_left_right(img_rgb)\n",
    "        # Channels-first\n",
    "        img_rgb = tf.transpose(img_rgb, [2, 0, 1])\n",
    "        return img_rgb, label\n",
    "\n",
    "\n",
    "    def _parse_function_inference(self, filename, label):\n",
    "        \"\"\"Load one validation/test example: resize to (height, width), no augmentation.\"\"\"\n",
    "        img_rgb, label = self._preprocess_image_labels(filename, label)\n",
    "        # Resize to final dimensions\n",
    "        img_rgb = tf.image.resize_images(img_rgb, [self.height, self.width])\n",
    "        # Channels-first\n",
    "        img_rgb = tf.transpose(img_rgb, [2, 0, 1])\n",
    "        return img_rgb, label\n",
    "\n",
    "\n",
    "    def _preprocess_image_labels(self, filename, label):\n",
    "        \"\"\"Decode a PNG as 3-channel float, subtract the mean, apply the scale; cast label to float32.\"\"\"\n",
    "        # load and preprocess the image\n",
    "        img_decoded = tf.to_float(tf.image.decode_png(tf.read_file(filename), channels=3))\n",
    "        img_centered = tf.subtract(img_decoded, self.imagenet_mean)\n",
    "        img_rgb = img_centered * self.imagenet_scaling\n",
    "        return img_rgb, tf.cast(label, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train:21563 valid:3080 test:6162\n"
     ]
    }
   ],
   "source": [
    "train_set, valid_set, test_set = get_train_valid_test_split(TOT_PATIENT_NUMBER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 87306 labels and 87306 images\n",
      "Loaded 7616 labels and 7616 images\n",
      "Loaded 17198 labels and 17198 images\n"
     ]
    }
   ],
   "source": [
    "with tf.device('/cpu:0'):\n",
    "    # Create dataset for iterator\n",
    "    train_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=train_set,  \n",
    "                             mode='training')\n",
    "    valid_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=valid_set,\n",
    "                             mode='validation')\n",
    "    test_dataset  = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=test_set,\n",
    "                             mode='testing')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_gradients(tower_grads):\n",
    "    \"\"\"Average each variable's gradient across all towers.\n",
    "\n",
    "    tower_grads holds one list of (gradient, variable) pairs per tower. The lists\n",
    "    are walked in lockstep, so entry k of every tower must refer to the same variable.\n",
    "    Returns one list of (mean_gradient, variable) pairs; the variable is taken from\n",
    "    the first tower.\n",
    "    \"\"\"\n",
    "    averaged = []\n",
    "    for per_tower_pairs in zip(*tower_grads):\n",
    "        # Give every gradient a leading tower axis, join along it, then average it away\n",
    "        with_tower_axis = [tf.expand_dims(tower_grad, 0) for tower_grad, _ in per_tower_pairs]\n",
    "        joined = tf.concat(axis=0, values=with_tower_axis)\n",
    "        shared_var = per_tower_pairs[0][1]\n",
    "        averaged.append((tf.reduce_mean(joined, 0), shared_var))\n",
    "    return averaged"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_symbol(in_tensor, out_features):\n",
    "    \"\"\"Build DenseNet-121 (NCHW) on in_tensor and return its output reshaped to (-1, out_features).\"\"\"\n",
    "    # The network is always built with is_training=True, see\n",
    "    # https://github.com/tensorflow/models/issues/3556\n",
    "    dense_scope = densenet.densenet_arg_scope(data_format=\"NCHW\")\n",
    "    with slim.arg_scope(dense_scope):\n",
    "        net_out, _ = densenet.densenet121(\n",
    "            in_tensor,\n",
    "            num_classes=out_features,\n",
    "            is_training=True)\n",
    "        # Need to reshape from (?, 1, 1, 14) to (?, 14)\n",
    "        flat_out = tf.reshape(net_out, shape=[-1, out_features])\n",
    "    return flat_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn_single(features, labels, mode, params):\n",
    "    \"\"\"Estimator model_fn that builds the network once (no manual device placement).\n",
    "\n",
    "    params must provide 'n_classes' (number of outputs) and 'lr' (Adam learning rate).\n",
    "    PREDICT returns sigmoid probabilities; TRAIN/EVAL return the sigmoid\n",
    "    cross-entropy loss, the train op and a streaming 'val_loss' metric.\n",
    "    \"\"\"\n",
    "    logits = get_symbol(features, out_features=params[\"n_classes\"])\n",
    "    # Predictions\n",
    "    predictions = tf.sigmoid(logits)\n",
    "    # ModeKeys.PREDICT\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\n",
    "    # Optimizer & Loss\n",
    "    optimizer = tf.train.AdamOptimizer(params['lr'], beta1=0.9, beta2=0.999)\n",
    "    # sigmoid_cross_entropy applies the sigmoid itself, so it must be fed the raw logits\n",
    "    loss = tf.losses.sigmoid_cross_entropy(labels, logits)\n",
    "    # Run the collected update ops (e.g. batch-norm moving statistics) with every\n",
    "    # training step, as model_fn_multigpu does for each of its towers\n",
    "    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "    with tf.control_dependencies(update_ops):\n",
    "        train_op = optimizer.minimize(loss, tf.train.get_global_step())\n",
    "    # Create eval metric ops.\n",
    "    # The metric used to be computed from `predictions`, i.e. the sigmoid was applied twice;\n",
    "    # it now tracks the same logits-based loss that is being optimised.\n",
    "    eval_metric_ops = {\"val_loss\": slim.metrics.streaming_mean(loss)}\n",
    "\n",
    "    return tf.estimator.EstimatorSpec(\n",
    "        mode=mode,\n",
    "        loss=loss,\n",
    "        train_op=train_op,\n",
    "        eval_metric_ops=eval_metric_ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_gpu_X_y_split(features, labels, batchsize, gpus):\n",
    "    \"\"\"Split features and labels along axis 0 into one shard per entry of gpus.\n",
    "\n",
    "    Every shard holds batchsize // len(gpus) rows except the last one, which also\n",
    "    takes the remainder so the shard sizes always add up to batchsize.\n",
    "    Returns (list_of_feature_shards, list_of_label_shards).\n",
    "    \"\"\"\n",
    "    n_shards = len(gpus)\n",
    "    rows_per_shard = batchsize // n_shards\n",
    "    shard_sizes = [rows_per_shard for _ in range(n_shards - 1)]\n",
    "    # Last shard absorbs whatever is left over\n",
    "    shard_sizes.append(batchsize - sum(shard_sizes))\n",
    "    return (tf.split(features, shard_sizes, axis=0),\n",
    "            tf.split(labels, shard_sizes, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn_multigpu(features, labels, mode, params):\n",
    "    \"\"\"Estimator model_fn that trains one tower per entry of params['gpus'].\n",
    "\n",
    "    params must provide 'n_classes', 'lr' (Adam learning rate), 'batchsize'\n",
    "    (size of the incoming batch, used to compute the per-tower split) and 'gpus'\n",
    "    (list of integer GPU ids). PREDICT builds a single tower and returns sigmoid\n",
    "    probabilities. TRAIN/EVAL split the batch across the towers, average the tower\n",
    "    gradients and report the mean tower loss plus a streaming 'val_loss' metric.\n",
    "    \"\"\"\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        # Create symbol\n",
    "        sym = get_symbol(features, out_features=params[\"n_classes\"])\n",
    "        # Predictions\n",
    "        predictions = tf.sigmoid(sym)\n",
    "        # ModeKeys.PREDICT\n",
    "        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\n",
    "\n",
    "    # For multi-gpu split features and labels\n",
    "    features_split, labels_split = multi_gpu_X_y_split(features, labels, params[\"batchsize\"], params[\"gpus\"])\n",
    "    tower_grads = []\n",
    "    tower_losses = []\n",
    "    eval_logits = []\n",
    "    # Training operation\n",
    "    global_step = tf.train.get_global_step()\n",
    "    # Learning rate comes from params (as in model_fn_single), not from the notebook-level LR global\n",
    "    optimizer = tf.train.AdamOptimizer(params['lr'], beta1=0.9, beta2=0.999)\n",
    "    # Load model on multiple GPUs\n",
    "    with tf.variable_scope(tf.get_variable_scope()):\n",
    "        # tower_idx picks the data shard and names the scope; gpu_id picks the device\n",
    "        for tower_idx, gpu_id in enumerate(params['gpus']):\n",
    "            with tf.device('/gpu:%d' % gpu_id), tf.name_scope('%s_%d' % (\"classification\", tower_idx)) as scope:\n",
    "                # Symbol\n",
    "                sym = get_symbol(features_split[tower_idx], out_features=params[\"n_classes\"])\n",
    "                # Loss (registered in the LOSSES collection under this tower's name scope)\n",
    "                tf.losses.sigmoid_cross_entropy(labels_split[tower_idx], sym)\n",
    "                # Training-ops: make the tower loss depend on this tower's update ops\n",
    "                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)\n",
    "                updates_op = tf.group(*update_ops)\n",
    "                with tf.control_dependencies([updates_op]):\n",
    "                    losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)\n",
    "                    total_loss = tf.add_n(losses, name='total_loss')\n",
    "                # reuse var: every tower after the first shares the variables created above\n",
    "                tf.get_variable_scope().reuse_variables()\n",
    "                # grad compute\n",
    "                grads = optimizer.compute_gradients(total_loss)\n",
    "                tower_grads.append(grads)\n",
    "                tower_losses.append(total_loss)\n",
    "                eval_logits.append(sym)\n",
    "\n",
    "    # We must calculate the mean of each gradient\n",
    "    grads = average_gradients(tower_grads)\n",
    "    # Apply the gradients to adjust the shared variables.\n",
    "    apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)\n",
    "    # Group all updates into a single train op.\n",
    "    train_op = tf.group(apply_gradient_op)\n",
    "    # Create eval metric ops (logits gathered from all towers, in shard order)\n",
    "    logits = tf.concat(eval_logits, 0)\n",
    "    eval_metric_ops = {\"val_loss\": slim.metrics.streaming_mean(\n",
    "        tf.losses.sigmoid_cross_entropy(labels, logits))}\n",
    "    # Report the mean over all towers instead of only the last tower's loss\n",
    "    mean_loss = tf.add_n(tower_losses) / float(len(tower_losses))\n",
    "\n",
    "    return tf.estimator.EstimatorSpec(\n",
    "        mode=mode,\n",
    "        loss=mean_loss,\n",
    "        train_op=train_op,\n",
    "        eval_metric_ops=eval_metric_ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_input_fn():\n",
    "    \"\"\"Next-batch tensors from a fresh one-shot iterator over train_dataset.\"\"\"\n",
    "    train_iter = train_dataset.data.make_one_shot_iterator()\n",
    "    return train_iter.get_next()\n",
    "\n",
    "\n",
    "def valid_input_fn():\n",
    "    \"\"\"Next-batch tensors from a fresh one-shot iterator over valid_dataset.\"\"\"\n",
    "    valid_iter = valid_dataset.data.make_one_shot_iterator()\n",
    "    return valid_iter.get_next()\n",
    "\n",
    "\n",
    "def test_input_fn():\n",
    "    \"\"\"Next-batch tensors from a fresh one-shot iterator over test_dataset.\"\"\"\n",
    "    test_iter = test_dataset.data.make_one_shot_iterator()\n",
    "    return test_iter.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warm-start every variable whose name does not contain 'logits' from the downloaded checkpoint\n",
    "ws = tf.estimator.WarmStartSettings(\n",
    "    ckpt_to_initialize_from=CHKPOINT,\n",
    "    vars_to_warm_start=\"^(?!.*(logits))\")\n",
    "# Hyper-parameters handed to the model function\n",
    "params = {\n",
    "    \"lr\": LR,\n",
    "    \"n_classes\": CLASSES,\n",
    "    \"batchsize\": BATCHSIZE,\n",
    "    \"gpus\": list(range(GPU_COUNT)),\n",
    "}\n",
    "# Pick the model function that matches the MULTI_GPU toggle\n",
    "model_fn = model_fn_multigpu if MULTI_GPU else model_fn_single"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpcfbb3x5w\n",
      "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_num_ps_replicas': 0, '_model_dir': '/tmp/tmpcfbb3x5w', '_log_step_count_steps': 100, '_master': '', '_is_chief': True, '_keep_checkpoint_every_n_hours': 10000, '_keep_checkpoint_max': 5, '_session_config': None, '_global_id_in_cluster': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f79611c6ef0>, '_task_id': 0, '_tf_random_seed': None, '_num_worker_replicas': 1, '_train_distribute': None, '_task_type': 'worker', '_evaluation_master': '', '_service': None}\n",
      "CPU times: user 5.5 ms, sys: 162 µs, total: 5.66 ms\n",
      "Wall time: 5.03 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Create Estimator\n",
    "nn = tf.estimator.Estimator(model_fn=model_fn, params=params, warm_start_from=ws)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 99 µs, sys: 11 µs, total: 110 µs\n",
      "Wall time: 115 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Create train & eval specs\n",
    "train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,\n",
    "                                    max_steps=EPOCHS*(train_dataset.data_size//BATCHSIZE))\n",
    "# Hard to run validation every epoch so playing around with throttle_secs to get 5 runs\n",
    "eval_spec = tf.estimator.EvalSpec(input_fn=valid_input_fn,\n",
    "                                  throttle_secs=400)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Running training and evaluation locally (non-distributed).\n",
      "Trimmed log ...",
      "CPU times: user 1h 57min 47s, sys: 9min 51s, total: 2h 7min 38s\n",
      "Wall time: 22min 11s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Main training loop: 33min 29s\n",
    "# 4 GPU - Main training loop: 22min 11s\n",
    "# Run train and evaluate (on validation data)\n",
    "tf.estimator.train_and_evaluate(nn, train_spec, eval_spec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /tmp/tmpcfbb3x5w/model.ckpt-1705\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "Full AUC [0.8113181918227069, 0.8603255067467157, 0.7916277297097226, 0.8790063417235601, 0.8791099267683136, 0.9115475111698327, 0.7023541384026243, 0.858907165510939, 0.6190543626664032, 0.8471454439175482, 0.7650787561316246, 0.7907428147497167, 0.7714048056837656, 0.8798529046882989]\n",
      "Test AUC: 0.8120\n",
      "CPU times: user 2min 58s, sys: 12.9 s, total: 3min 11s\n",
      "Wall time: 32.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU AUC: 0.8009\n",
    "# 4 GPU AUC: 0.8120\n",
    "predictions = list(nn.predict(test_input_fn))\n",
    "y_truth = test_dataset.labels\n",
    "y_guess = np.array(predictions)\n",
    "print(\"Test AUC: {0:.4f}\".format(compute_roc_auc(y_truth, y_guess, CLASSES))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Synthetic Data (Pure Training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test on fake-data -> no IO lag\n",
    "# Only whole batches are generated, so every step sees exactly BATCHSIZE rows\n",
    "batch_in_epoch = train_dataset.data_size//BATCHSIZE\n",
    "tot_num = batch_in_epoch * BATCHSIZE\n",
    "# Channels-first, with the same HEIGHT x WIDTH the real input pipeline produces\n",
    "# (previously hard-coded as 224 x 224)\n",
    "fake_X = np.random.rand(tot_num, 3, HEIGHT, WIDTH).astype(np.float32)\n",
    "fake_y = np.random.rand(tot_num, CLASSES).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmptsc3de0g\n",
      "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_num_ps_replicas': 0, '_model_dir': '/tmp/tmptsc3de0g', '_log_step_count_steps': 100, '_master': '', '_is_chief': True, '_keep_checkpoint_every_n_hours': 10000, '_keep_checkpoint_max': 5, '_session_config': None, '_global_id_in_cluster': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f796293a160>, '_task_id': 0, '_tf_random_seed': None, '_num_worker_replicas': 1, '_train_distribute': None, '_task_type': 'worker', '_evaluation_master': '', '_service': None}\n",
      "CPU times: user 4.85 ms, sys: 54 µs, total: 4.91 ms\n",
      "Wall time: 4.14 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Create Estimator\n",
    "nn = tf.estimator.Estimator(model_fn=model_fn, params=params)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmptsc3de0g/model.ckpt.\n",
      "INFO:tensorflow:loss = 0.76794666, step = 0\n",
      "INFO:tensorflow:global_step/sec: 1.9302\n",
      "INFO:tensorflow:loss = 0.6935522, step = 100 (51.812 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.44545\n",
      "INFO:tensorflow:loss = 0.69338816, step = 200 (40.893 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.39242\n",
      "INFO:tensorflow:loss = 0.69193834, step = 300 (41.798 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.40031\n",
      "INFO:tensorflow:loss = 0.6942548, step = 400 (41.661 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.3377\n",
      "INFO:tensorflow:loss = 0.6941431, step = 500 (42.778 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.34509\n",
      "INFO:tensorflow:loss = 0.6940154, step = 600 (42.642 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.31676\n",
      "INFO:tensorflow:loss = 0.6940793, step = 700 (43.164 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.33261\n",
      "INFO:tensorflow:loss = 0.69123954, step = 800 (42.870 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.32733\n",
      "INFO:tensorflow:loss = 0.69524056, step = 900 (42.969 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.30784\n",
      "INFO:tensorflow:loss = 0.69289315, step = 1000 (43.329 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.33822\n",
      "INFO:tensorflow:loss = 0.69344753, step = 1100 (42.768 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.27745\n",
      "INFO:tensorflow:loss = 0.6898188, step = 1200 (43.908 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.26344\n",
      "INFO:tensorflow:loss = 0.691602, step = 1300 (44.183 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1373 into /tmp/tmptsc3de0g/model.ckpt.\n",
      "INFO:tensorflow:global_step/sec: 2.10704\n",
      "INFO:tensorflow:loss = 0.6930033, step = 1400 (47.457 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.28343\n",
      "INFO:tensorflow:loss = 0.69113714, step = 1500 (43.794 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.2884\n",
      "INFO:tensorflow:loss = 0.6936152, step = 1600 (43.699 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.25189\n",
      "INFO:tensorflow:loss = 0.6927664, step = 1700 (44.407 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1705 into /tmp/tmptsc3de0g/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.69478613.\n",
      "CPU times: user 32min 54s, sys: 12min 23s, total: 45min 18s\n",
      "Wall time: 13min 55s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.estimator.estimator.Estimator at 0x7f796293a128>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Synthetic data: 25min 25s \n",
    "# 4 GPU - Synthetic data: 13min 55s\n",
    "nn.train(tf.estimator.inputs.numpy_input_fn(\n",
    "    fake_X,\n",
    "    fake_y,\n",
    "    shuffle=False,\n",
    "    num_epochs=EPOCHS,\n",
    "    batch_size=BATCHSIZE))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py35]",
   "language": "python",
   "name": "conda-env-py35-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
